import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from SVMAgent import SVMAgent, FastComNetwork, SVMOracle, DatasetModel, DGFMAgent, ComNetwork, DGFMplus

# Hyperparameter configuration
NUM_AGENTS = 16  # Number of agents; also passed to DatasetModel as num_agent
NUM_ROUNDS = 2000  # Number of rounds passed to the training calls
T_RESTART = 100  # Restart period passed to the DOC²S and MEDOL trainers as t_restart
DELTA = 0.001  # Passed as `delta` to the zeroth-order gradient oracle (get_zo_grad)
LR = 0.01  # DOC²S learning rate
D = 0.01  # DOC²S online learning bound (passed to the DOC²S agents as D)
m_LR = 0.01  # MEDOL learning rate
m_D = 0.001  # MEDOL online learning bound
DGFM_LR1 = 0.001  # DGFM learning rate
DGFM_LR2 = 0.001  # DGFM+ learning rate
R = 1  # Chebyshev acceleration rounds; passed to the DOC²S communication calls
p = 0.99  # Diagonal entry of the ring mixing matrix built by ring_matrix
BATCH_SIZE = 128  # Passed to DatasetModel as mb_size
DATASET_NAME = 'rcv'  # Dataset name passed to DatasetModel; also used in the output path
# Seed Python's and NumPy's global random number generators
random.seed(42)
np.random.seed(42)

# Initialize the dataset and the oracle shared by all training functions
dataset = DatasetModel(dsname=DATASET_NAME, num_agent=NUM_AGENTS, mb_size=BATCH_SIZE)
oracle = SVMOracle(alpha=2, lam=1e-5)

# Ring matrix
def ring_matrix(n, p):
    """Build the mixing matrix of a ring where each node links to its two neighbours.

    Args:
        n: Number of nodes in the ring.
        p: Value placed on every diagonal entry (a node's own weight).

    Returns:
        An ``(n, n)`` array with ``p`` on the diagonal, ``(1 - p) / 2`` on the
        two cyclically adjacent entries of each row, and zeros elsewhere.

    Raises:
        AssertionError: If ``2 * k < n - 1`` does not hold for the fixed ``k = 1``.
    """
    k = 1  # neighbourhood radius, fixed to the immediate neighbours
    assert 2 * k < n - 1, 'k must be smaller than n/2'

    W = np.zeros((n, n))
    nodes = np.arange(n)
    neighbour_weight = (1 - p) / 2

    # Modulo arithmetic wraps the first and the last node around to each other.
    W[nodes, (nodes - 1) % n] = neighbour_weight
    W[nodes, (nodes + 1) % n] = neighbour_weight
    W[nodes, nodes] = p
    return W

def create_matrix(n):
    """Return the ``(n, n)`` uniform matrix whose entries all equal ``1 / n``."""
    uniform_weight = 1 / n
    W = np.empty((n, n))
    W.fill(uniform_weight)
    return W

# DOC²S training function
def train_DOC2S(agents, num_rounds, t_restart, oracle_type):
    """Run DOC²S: one randomly sampled client updates per round, then all agents communicate.

    Args:
        agents: List of agents; position ``i`` is treated as client ``i``.
        num_rounds: Number of rounds to run.
        t_restart: Restart period. On the last round of each period every
            agent's action is re-initialised.
        oracle_type: ``'1st'`` or any other value. The zeroth-order oracle
            call is disabled, so every value uses ``oracle.get_gradients``.

    Returns:
        A list with one entry per round: the test-set function value at the
        network-averaged weight. The per-round maximum over the individual
        agents is computed as well but is not returned.
    """
    network = FastComNetwork(create_matrix(NUM_AGENTS))
    avg_history = []
    worst_history = []  # recorded every round, not part of the return value

    last_in_period = t_restart - 1
    for round_idx in range(num_rounds):
        # Periodic restart of every agent's action.
        if round_idx % t_restart == last_in_period:
            for agent in agents:
                agent.initialize_action()

        # Sample the single client that performs the local update this round.
        selected = np.random.randint(NUM_AGENTS)
        x_mb, y_mb = dataset.get_sample(selected)

        for agent_idx in range(NUM_AGENTS):
            agents[agent_idx].get_grad_point()
        active = agents[selected]
        active.set_weight(active.DOC2S_get_new_weight())
        grad_point = active.get_grad_points()

        # Same oracle call for every oracle_type (the zeroth-order variant is disabled).
        grad = oracle.get_gradients(grad_point, x_mb, y_mb)

        for agent_idx, agent in enumerate(agents):
            if agent_idx != selected:
                # Unselected clients contribute a zero action.
                agent.set_action(np.zeros_like(agent.get_action()))
                continue

            # Selected client: gradient step, shrink onto the ball of radius
            # agent.D, then scale by the agent's NUM_AGENTS attribute.
            stepped = agent.get_action() - agent.lr * grad
            step_norm = np.linalg.norm(stepped)
            if step_norm > 1e-8:
                shrink = min(1, agent.D / step_norm)
            else:
                shrink = 1.0
            agent.set_action(agent.NUM_AGENTS * shrink * stepped)

        # Communication with R rounds for both actions and weights.
        network.propagate_actions(agents, R)
        network.propagate_weights(agents, R)

        # Function value at the averaged weight.
        mean_weight = network.get_average_weight(agents)
        avg_history.append(oracle.get_fn_val(mean_weight, *dataset.get_test_set()))

        # Worst function value over the individual agents.
        current_weights = [agent.get_weight() for agent in agents]
        per_agent = [oracle.get_fn_val(weight, *dataset.get_test_set())
                     for weight in current_weights]
        worst_history.append(max(per_agent))

    return avg_history

# MEDOL training function
def train_MEDOL(agents, num_rounds, t_restart, oracle_type):
    """Run MEDOL: every agent updates in every round, then all agents communicate on the ring.

    Args:
        agents: List of agents; position ``i`` draws its samples as client ``i``.
        num_rounds: Number of rounds to run.
        t_restart: Restart period. On the last round of each period every
            agent's action is re-initialised.
        oracle_type: ``'1st'`` or any other value. The zeroth-order oracle
            call is disabled, so every value uses ``oracle.get_gradients``.

    Returns:
        A list with one entry per round: the test-set function value at the
        network-averaged weight. The per-round maximum over the individual
        agents is computed as well but is not returned.
    """
    network = ComNetwork(ring_matrix(NUM_AGENTS, p))
    avg_history = []
    worst_history = []  # recorded every round, not part of the return value

    last_in_period = t_restart - 1
    for round_idx in range(num_rounds):
        # Periodic restart of every agent's action.
        if round_idx % t_restart == last_in_period:
            for agent in agents:
                agent.initialize_action()
                agent.get_weight()  # return value is not used

        # Local update on every agent.
        for agent_idx in range(NUM_AGENTS):
            worker = agents[agent_idx]
            worker.get_grad_point()
            worker.set_weight(worker.get_new_weight())
            grad_point = worker.get_grad_points()

            x1, y1 = dataset.get_sample(agent_idx)
            # Same oracle call for every oracle_type (the zeroth-order variant is disabled).
            grad = oracle.get_gradients(grad_point, x1, y1)
            worker.action_grad_update(grad)

        # Communication of actions and weights.
        network.propagate_actions(agents)
        network.propagate_weights(agents)

        # Function value at the averaged weight.
        mean_weight = network.get_average_weight(agents)
        avg_history.append(oracle.get_fn_val(mean_weight, *dataset.get_test_set()))

        # Worst function value over the individual agents.
        current_weights = [agent.get_weight() for agent in agents]
        per_agent = [oracle.get_fn_val(weight, *dataset.get_test_set())
                     for weight in current_weights]
        worst_history.append(max(per_agent))

    return avg_history

# DGFM training function
def train_DGFM(agents, num_rounds):
    """Run DGFM on the ring network and return the averaged-weight function value per round.

    Every round each agent draws a mini-batch, obtains a zeroth-order
    gradient estimate at its current weight and feeds it to
    ``update_y_grad``. The gradient variables are then propagated, every
    agent updates its weight, and the weights are propagated.

    Args:
        agents: List of DGFM agents; each agent's ``id`` selects its data.
        num_rounds: Number of rounds to run.

    Returns:
        A list with one entry per round: the test-set function value at the
        network-averaged weight. This is the same metric the other training
        functions in this file return, so all curves in the comparison plot
        show the same quantity.
    """
    network = ComNetwork(ring_matrix(NUM_AGENTS, p))
    losses = []
    max_fun = []  # per-round worst agent; recorded but not returned

    for _ in range(num_rounds):
        # Local zeroth-order gradient estimate on every agent.
        for agent in agents:
            x_mb, y_mb = dataset.get_sample(agent.id)
            grad = oracle.get_zo_grad(agent.get_weight(), x_mb, y_mb, delta=DELTA)
            agent.update_y_grad(grad)

        # Communicate gradient variables, step, then communicate weights.
        network.propagate_dgfm_grad(agents)
        for agent in agents:
            agent.update_weight()
        network.propagate_weights(agents)

        # Function value at the averaged weight.
        avg_weight = network.get_average_weight(agents)
        losses.append(oracle.get_fn_val(avg_weight, *dataset.get_test_set()))

        # Worst function value over the individual agents.
        agent_losses = [oracle.get_fn_val(agent.get_weight(), *dataset.get_test_set())
                        for agent in agents]
        max_fun.append(max(agent_losses))

    # Previously returned max_fun, which made this curve a different metric
    # from the averaged-weight loss returned by the other trainers.
    return losses

# DGFM+ training function
def train_DGFM_plus(agents, num_rounds, T_restart=10, mega_batch=512):
    """Run DGFM+ on the ring network with a periodic large-batch restart.

    Args:
        agents: List of agents exposing ``set_v``, ``save_prev_grad``,
            ``update_spider_grad``, ``update_weight``, ``get_weight`` and ``id``.
        num_rounds: Number of rounds to run.
        T_restart: Restart period; the restart phase runs whenever the round
            index is a multiple of it (including round 0).
        mega_batch: Accepted but not used by this function.

    Returns:
        A list with one entry per round: the test-set function value at the
        network-averaged weight. The per-round maximum over the individual
        agents is computed as well but is not returned.
    """
    network = FastComNetwork(ring_matrix(NUM_AGENTS, p))
    avg_history = []
    worst_history = []  # recorded every round, not part of the return value

    regular_batch = 64   # mini-batch size of an ordinary round
    restart_batch = 256  # mini-batch size used in the restart phase

    for round_idx in range(num_rounds):
        if round_idx % T_restart == 0:
            # Restart: reset each v to a fresh large-batch gradient estimate ...
            for agent in agents:
                x_mb, y_mb = dataset.get_sample_DGFM(agent.id, mb_size=restart_batch)
                agent.set_v(oracle.get_zo_grad(agent.get_weight(), x_mb, y_mb, delta=DELTA))

            # ... mix it over the network with 5 communication rounds ...
            network.propagate_v(agents, R=5)

            # ... and store the current gradient as the previous one.
            for agent in agents:
                agent.save_prev_grad()

        # Regular iteration: SPIDER-style update on every agent.
        for agent in agents:
            x_mb, y_mb = dataset.get_sample_DGFM(agent.id, mb_size=regular_batch)
            fresh_grad = oracle.get_zo_grad(agent.get_weight(), x_mb, y_mb, delta=DELTA)
            agent.update_spider_grad(fresh_grad)

        # One communication round for v, then weight step and weight communication.
        network.propagate_v(agents, R=1)
        for agent in agents:
            agent.update_weight()
        network.propagate_weights(agents, R=1)

        # Function value at the averaged weight.
        mean_weight = network.get_average_weight(agents)
        avg_history.append(oracle.get_fn_val(mean_weight, *dataset.get_test_set()))

        # Worst function value over the individual agents.
        per_agent = []
        for agent in agents:
            per_agent.append(oracle.get_fn_val(agent.get_weight(), *dataset.get_test_set()))
        worst_history.append(max(per_agent))

    return avg_history

# Initialize agents
doc2s_agents = [SVMAgent(dataset.num_param, id=i, lr=LR, D=D, NUM_AGENTS=NUM_AGENTS) for i in range(NUM_AGENTS)]
medol_agents = [SVMAgent(dataset.num_param, id=i, lr=m_LR, D=m_D, NUM_AGENTS=NUM_AGENTS) for i in range(NUM_AGENTS)]
dgfm_agents = [DGFMAgent(dataset.num_param, id=i, lr=DGFM_LR1) for i in range(NUM_AGENTS)]
# train_DGFM_plus calls set_v / save_prev_grad / update_spider_grad on its agents,
# so DGFM+ is given DGFMplus agents rather than plain DGFMAgent ones.
# NOTE(review): assumes DGFMplus accepts the same constructor arguments as
# DGFMAgent — confirm against the SVMAgent module.
dgfmp_agents = [DGFMplus(dataset.num_param, id=i, lr=DGFM_LR2) for i in range(NUM_AGENTS)]

# Training results
doc2s_loss = train_DOC2S(doc2s_agents, NUM_ROUNDS, T_RESTART, '0th')
medol_loss = train_MEDOL(medol_agents, NUM_ROUNDS, T_RESTART, '0th')
dgfm_loss = train_DGFM(dgfm_agents, NUM_ROUNDS)
# The DGFM+ curve must come from the DGFM+ trainer; it was previously produced
# by train_DGFM, which made it a second DGFM run under a different label.
dgfmp_loss = train_DGFM_plus(dgfmp_agents, NUM_ROUNDS)

step = 100
# Keep one recorded value every `step` rounds; x holds the sampled round
# indices [0, 100, 200, ...] and doubles as the x-axis of the plot.
x = np.arange(0, NUM_ROUNDS, step)
doc2s_loss = np.array(doc2s_loss)  # Convert to NumPy array so it can be indexed by x
medol_loss = np.array(medol_loss)
DGFM_loss = np.array(dgfm_loss)
DGFMp_loss = np.array(dgfmp_loss)

doc2s_loss_sampled = doc2s_loss[x]
medol_loss_sampled = medol_loss[x]
dgfm_loss_sampled = DGFM_loss[x]
dgfmp_loss_sampled = DGFMp_loss[x]

# Set up canvas
plt.figure(figsize=(9, 8))

# Plot curves against the sampled round indices instead of the default 0..len-1 index
plt.plot(x, doc2s_loss_sampled, label='DOC²S', color='black', marker='.',
         markersize=11, linewidth=1.5)
plt.plot(x, medol_loss_sampled, label='MEDOL', color='red', linestyle='--',
         marker='^', markersize=11, linewidth=1.5)
plt.plot(x, dgfm_loss_sampled, label='DGFM', color='blue', linestyle='-.', marker='s',
         markersize=11, linewidth=1.5)
plt.plot(x, dgfmp_loss_sampled, label='DGFM+', color='green',
         marker='^', markersize=11, linewidth=1.5)

# Axis labels
plt.xlabel(r"$\mathrm{Computation~rounds}$", fontsize=31)
plt.ylabel(r"$\mathrm{Function~value}$", fontsize=31)

# Axis range and ticks
plt.xlim(0, NUM_ROUNDS + 30)
plt.xticks(np.arange(0, NUM_ROUNDS, step * 3),  # One tick every 3 * step = 300 rounds
           fontsize=20)
plt.yticks(fontsize=20)

# Legend, created once with its final styling
plt.legend(fontsize=27, framealpha=0.9)

# Add grid lines
plt.grid(True, linestyle='--', alpha=0.3)

# Finalise the layout before saving; calling it after savefig would only
# affect the on-screen figure, not the PDF.
plt.tight_layout()

# Save as PDF
base_dir = "D:/Desktop Files/Experiments"
file_name = DATASET_NAME  # Can be changed to any filename
output_dir = Path(base_dir) / "SVM" / file_name / "0th"
# Create the folder if needed so a missing directory cannot discard the
# figure after all training runs have finished.
output_dir.mkdir(parents=True, exist_ok=True)

plt.savefig(output_dir / "computation.pdf",
            bbox_inches='tight',
            dpi=400,
            facecolor='white')

plt.show()